package neuroph;

import cn.hutool.core.date.DateUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.ybg.ShareApplication;
import com.ybg.share.core.dbapi.entity.ShareStockDayK;
import com.ybg.share.core.dbapi.service.ShareStockDayKService;
import com.ybg.share.framework.neural.listener.RBFListener;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.RBFNetwork;
import org.neuroph.nnet.learning.RBFLearning;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;

import java.util.List;

/**
 * Exploratory integration test that trains a Neuroph radial-basis-function network on the
 * daily K-line rows of a single stock and prints the prediction for one sample input.
 *
 * <p>Network inputs are (previous close, open price); the single output is the day's change
 * range. Requires the full Spring Boot context ({@link ShareApplication}) and a database
 * reachable through {@link ShareStockDayKService}.
 */
@SpringBootTest(classes = ShareApplication.class)
@RunWith(SpringRunner.class)
public class StockDayK {

    /** Id of the stock whose day-K rows are used as training data. */
    private static final long TRAINING_STOCK_ID = 1L;

    /** Number of network inputs: previous close and open price. */
    private static final int INPUT_COUNT = 2;

    /** Number of neurons in the RBF (hidden) layer. */
    private static final int RBF_NEURON_COUNT = 40;

    /** Number of network outputs: change range. */
    private static final int OUTPUT_COUNT = 1;

    /** Learning rate passed to {@link RBFLearning#setLearningRate(double)}. */
    private static final double LEARNING_RATE = 0.00001;

    /**
     * Stop threshold passed to {@link RBFLearning#setMaxError(double)}.
     * NOTE(review): 3000 is a very loose threshold; confirm this is the intended value.
     */
    private static final double MAX_ERROR = 3000;

    @Autowired
    ShareStockDayKService shareStockDayKService;

    /**
     * Loads the day-K rows for {@link #TRAINING_STOCK_ID}, trains the RBF network on them and
     * prints the network output for a fixed sample input.
     *
     * <p>Exceptions are deliberately not caught so that any failure (database, data, training)
     * fails the test instead of being printed and ignored.
     *
     * @throws IllegalStateException if no usable training rows were found
     */
    @Test
    public void indexGenNeuroph() {
        QueryWrapper<ShareStockDayK> wrapper = new QueryWrapper<>();
        wrapper.eq(ShareStockDayK.STOCK_ID, TRAINING_STOCK_ID);
        System.out.println("查询股");
        List<ShareStockDayK> list = shareStockDayKService.list(wrapper);
        System.out.println("数据量:" + list.size());

        DataSet trainingSet = buildTrainingSet(list);
        if (trainingSet.size() == 0) {
            // Training on an empty set is meaningless; report it clearly instead of
            // letting the learning rule fail (or silently do nothing).
            throw new IllegalStateException(
                    "No usable day-K rows for stock id " + TRAINING_STOCK_ID);
        }

        RBFNetwork network = new RBFNetwork(INPUT_COUNT, RBF_NEURON_COUNT, OUTPUT_COUNT);
        RBFLearning learningRule = (RBFLearning) network.getLearningRule();
        System.out.println("默认学习率" + learningRule.getLearningRate());
        System.out.println("默认最大错误数" + learningRule.getMaxError());
        learningRule.setLearningRate(LEARNING_RATE);
        learningRule.setMaxError(MAX_ERROR);
        learningRule.addListener(new RBFListener());

        System.out.println("循环结束" + DateUtil.now());
        network.learn(trainingSet);
        System.out.println("学习结束" + DateUtil.now());

        // Sample input: (previous close, open price).
        network.setInput(14.67, 14.51);
        network.calculate();
        System.out.println("准备输出结果" + DateUtil.now());
        double[] networkOutput = network.getOutput();
        System.out.println("结果=" + networkOutput[0]);
    }

    /**
     * Converts day-K rows into a training set of (previous close, open price) -> change range.
     *
     * <p>Rows in which any of the three used columns is {@code null} are skipped, because
     * unboxing them would otherwise abort the whole run with a NullPointerException.
     *
     * @param rows day-K rows to convert; may be empty
     * @return a training set with {@link #INPUT_COUNT} inputs and {@link #OUTPUT_COUNT} outputs
     */
    private static DataSet buildTrainingSet(List<ShareStockDayK> rows) {
        DataSet trainingSet = new DataSet(INPUT_COUNT, OUTPUT_COUNT);
        int skipped = 0;
        for (ShareStockDayK dayK : rows) {
            if (dayK.getBeforeClose() == null
                    || dayK.getOpenPrice() == null
                    || dayK.getChangeRange() == null) {
                skipped++;
                continue;
            }
            double[] input = {
                dayK.getBeforeClose().doubleValue(), dayK.getOpenPrice().doubleValue()
            };
            double[] output = {dayK.getChangeRange().doubleValue()};
            trainingSet.add(new DataSetRow(input, output));
        }
        if (skipped > 0) {
            System.out.println("Skipped day-K rows with null price fields: " + skipped);
        }
        return trainingSet;
    }
}
